# Copyright (C) 2012-2013 Claudio Guarnieri.
# Copyright (C) 2014-2018 Cuckoo Foundation.
# This file is part of Cuckoo Sandbox - http://www.cuckoosandbox.org
# See the file 'docs/LICENSE' for copying permission.

import base64
import binascii
import hashlib
import logging
import mmap
import os
import pefile
import re
import shutil
import tempfile
import zipfile

from cuckoo.common.whitelist import is_whitelisted_domain
from cuckoo.compat import magic

try:
    import pydeep
    HAVE_PYDEEP = True
except ImportError:
    HAVE_PYDEEP = False

log = logging.getLogger(__name__)

FILE_CHUNK_SIZE = 16*1024*1024

URL_REGEX = (
    # HTTP/HTTPS.
    "(https?:\\/\\/)"
    "((["
    # IP address.
    "(?:[0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\\."
    "(?:[0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\\."
    "(?:[0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\\."
    "(?:[0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])]|"
    # Or domain name.
    "[a-zA-Z0-9\\.-]+)"
    # Optional port.
    "(\\:\\d+)?"
    # URI.
    "(/[\\(\\)a-zA-Z0-9_:%?=/\\.-]*)?"
)

PUBPRIVKEY = (
    "("
    "(?:-----BEGIN PUBLIC KEY-----"
    "[a-zA-Z0-9\\n\\+/]+"
    "-----END PUBLIC KEY-----)"
    "|"
    "(?:-----BEGIN RSA PRIVATE KEY-----"
    "[a-zA-Z0-9\\n\\+/]+"
    "-----END RSA PRIVATE KEY-----)"
    ")"
)

class Dictionary(dict):
    """Cuckoo custom dict."""

    def __getattr__(self, key):
        return self.get(key, None)

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

class URL:
    """URL base object."""

    def __init__(self, url):
        """@param url: URL"""
        self.url = url

class File(object):
    """Basic file object class with all useful utilities."""
    # Given that ssdeep hashes are not really used much in practice we're just
    # going to disable its warning by default for now.
    notified_pydeep = True

    # The yara rules should not change during one Cuckoo run and as such we're
    # caching 'em. This dictionary is filled during init_yara().
    yara_rules = {}

    def __init__(self, file_path, temporary=False):
        """@param file_path: file path."""
        self.file_path = file_path
        self.temporary = temporary

        # these will be populated when first accessed
        self._file_data = None
        self._crc32 = None
        self._md5 = None
        self._sha1 = None
        self._sha256 = None
        self._sha512 = None

    def __del__(self):
        self.temporary and os.unlink(self.file_path)

    def get_name(self):
        """Get file name.
        @return: file name.
        """
        file_name = os.path.basename(self.file_path)
        return file_name

    def valid(self):
        return os.path.exists(self.file_path) and \
            os.path.isfile(self.file_path) and \
            os.path.getsize(self.file_path) != 0

    def get_data(self):
        """Read file contents.
        @return: data.
        """
        return self.file_data

    def get_chunks(self):
        """Read file contents in chunks (generator)."""

        with open(self.file_path, "rb") as fd:
            while True:
                chunk = fd.read(FILE_CHUNK_SIZE)
                if not chunk:
                    break
                yield chunk

    def calc_hashes(self):
        """Calculate all possible hashes for this file."""
        crc = 0
        md5 = hashlib.md5()
        sha1 = hashlib.sha1()
        sha256 = hashlib.sha256()
        sha512 = hashlib.sha512()

        for chunk in self.get_chunks():
            crc = binascii.crc32(chunk, crc)
            md5.update(chunk)
            sha1.update(chunk)
            sha256.update(chunk)
            sha512.update(chunk)

        self._crc32 = "".join("%02X" % ((crc >> i) & 0xff)
                              for i in [24, 16, 8, 0])
        self._md5 = md5.hexdigest()
        self._sha1 = sha1.hexdigest()
        self._sha256 = sha256.hexdigest()
        self._sha512 = sha512.hexdigest()

    @property
    def file_data(self):
        if not self._file_data:
            self._file_data = open(self.file_path, "rb").read()
        return self._file_data

    def get_size(self):
        """Get file size.
        @return: file size.
        """
        return os.path.getsize(self.file_path)

    def get_crc32(self):
        """Get CRC32.
        @return: CRC32.
        """
        if not self._crc32:
            self.calc_hashes()
        return self._crc32

    def get_md5(self):
        """Get MD5.
        @return: MD5.
        """
        if not self._md5:
            self.calc_hashes()
        return self._md5

    def get_sha1(self):
        """Get SHA1.
        @return: SHA1.
        """
        if not self._sha1:
            self.calc_hashes()
        return self._sha1

    def get_sha256(self):
        """Get SHA256.
        @return: SHA256.
        """
        if not self._sha256:
            self.calc_hashes()
        return self._sha256

    def get_sha512(self):
        """
        Get SHA512.
        @return: SHA512.
        """
        if not self._sha512:
            self.calc_hashes()
        return self._sha512

    def get_ssdeep(self):
        """Get SSDEEP.
        @return: SSDEEP.
        """
        if not HAVE_PYDEEP:
            if not File.notified_pydeep:
                File.notified_pydeep = True
                log.warning("Unable to import pydeep (install with `pip install pydeep`)")
            return None

        try:
            return pydeep.hash_file(self.file_path)
        except Exception:
            return None

    def get_type(self):
        """Get MIME file type.
        @return: file type.
        """
        return magic.from_file(os.path.realpath(self.file_path))

    def get_content_type(self):
        """Get MIME content file type (example: image/jpeg).
        @return: file content type.
        """
        return magic.from_file(os.path.realpath(self.file_path), mime=True)

    def get_exported_functions(self):
        """Get the exported function names of this PE file."""
        filetype = self.get_type()
        if "MS-DOS" not in filetype and "PE32" not in self.get_type():
            return

        try:
            pe = pefile.PE(self.file_path)
            if not hasattr(pe, "DIRECTORY_ENTRY_EXPORT"):
                return

            for export in pe.DIRECTORY_ENTRY_EXPORT.symbols:
                if export.name:
                    yield export.name
        except Exception as e:
            log.warning("Error enumerating exported functions: %s", e)

    def get_imported_functions(self):
        """Get the imported functions of this PE file."""
        filetype = self.get_type()
        if "MS-DOS" not in filetype and "PE32" not in self.get_type():
            return

        try:
            pe = pefile.PE(self.file_path)
            if not hasattr(pe, "DIRECTORY_ENTRY_IMPORT"):
                return

            for imp in pe.DIRECTORY_ENTRY_IMPORT:
                for entry in imp.imports:
                    yield dict(dll=imp.dll,
                               name=entry.name,
                               ordinal=entry.ordinal,
                               hint=entry.hint,
                               address=entry.address)
        except Exception as e:
            log.warning("Error enumerating imported functions: %s", e)

    def get_apk_entry(self):
        """Get the entry point for this APK. The entry point is denoted by a
        package and main activity name."""
        filetype = self.get_type()
        if "Zip archive data" not in filetype and "Java archive data" not in filetype:
            return "", ""

        from androguard.core.bytecodes.apk import APK

        try:
            a = APK(self.file_path)
            if not a.is_valid_APK():
                return "", ""

            package = a.get_package()
            if not package:
                log.warning("Unable to find the main package, this analysis "
                            "will probably fail.")
                return "", ""

            main_activity = a.get_main_activity()
            if main_activity:
                log.debug("Picked package %s and main activity %s.",
                          package, main_activity)
                return package, main_activity

            activities = a.get_activities()
            for activity in activities:
                if "main" in activity or "start" in activity:
                    log.debug("Choosing package %s and main activity due to "
                              "its name %s.", package, activity)
                    return package, activity

            if activities and activities[0]:
                log.debug("Picked package %s and the first activity %s.",
                          package, activities[0])
                return package, activities[0]
        except Exception as e:
            log.warning("Error extracting package and main activity: %s.", e)

        return "", ""

    def get_yara(self, category="binaries", externals=None):
        """Get Yara signatures matches.
        @return: matched Yara signatures.
        """
        if not os.path.getsize(self.file_path):
            return []

        try:
            # TODO Once Yara obtains proper Unicode filepath support we can
            # remove this check. See also the following Github issue:
            # https://github.com/VirusTotal/yara-python/issues/48
            assert len(str(self.file_path)) == len(self.file_path)
        except (UnicodeEncodeError, AssertionError):
            log.warning(
                "Can't run Yara rules on %r as Unicode paths are currently "
                "not supported in combination with Yara!", self.file_path
            )
            return []

        results, rule = [], File.yara_rules[category]
        for match in rule.match(self.file_path, externals=externals):
            strings, offsets = set(), {}
            for _, key, value in match.strings:
                strings.add(base64.b64encode(value))
                offsets[key.lstrip("$")] = []

            strings = sorted(strings)
            for offset, key, value in match.strings:
                offsets[key.lstrip("$")].append(
                    (offset, strings.index(base64.b64encode(value)))
                )

            meta = {
                "description": "(no description)",
            }
            meta.update(match.meta)

            results.append({
                "name": match.rule,
                "meta": meta,
                "strings": strings,
                "offsets": offsets,
            })

        return results

    def mmap(self, fileno):
        if hasattr(mmap, "PROT_READ"):
            access = mmap.PROT_READ
        elif hasattr(mmap, "ACCESS_READ"):
            access = mmap.ACCESS_READ
        else:
            log.warning(
                "Regexing through a file is not supported on your OS!"
            )
            return

        return mmap.mmap(fileno, 0, access=access)

    def get_urls(self):
        """Extract all URLs embedded in this file through a simple regex."""
        if not os.path.getsize(self.file_path):
            return []

        # http://stackoverflow.com/a/454589
        urls, f = set(), open(self.file_path, "rb")
        for url in re.findall(URL_REGEX, self.mmap(f.fileno())):
            if not is_whitelisted_domain(url[1]):
                urls.add("".join(url))
        return list(urls)

    def get_keys(self):
        """Get any embedded plaintext public and/or private keys."""
        if not os.path.getsize(self.file_path):
            return []

        f = open(self.file_path, "rb")
        return list(set(re.findall(PUBPRIVKEY, self.mmap(f.fileno()))))

    def get_all(self):
        """Get all information available.
        @return: information dict.
        """
        infos = {}
        infos["name"] = self.get_name()
        infos["path"] = self.file_path
        infos["size"] = self.get_size()
        infos["crc32"] = self.get_crc32()
        infos["md5"] = self.get_md5()
        infos["sha1"] = self.get_sha1()
        infos["sha256"] = self.get_sha256()
        infos["sha512"] = self.get_sha512()
        infos["ssdeep"] = self.get_ssdeep()
        infos["type"] = self.get_type()
        infos["yara"] = self.get_yara()
        infos["urls"] = self.get_urls()
        return infos

class Archive(object):
    def __init__(self, filepath):
        self.filepath = filepath
        self.z = zipfile.ZipFile(filepath)

    def get_file(self, filename):
        filepath = tempfile.mktemp()
        shutil.copyfileobj(self.z.open(filename), open(filepath, "wb"))
        return File(filepath, temporary=True)

class Buffer(object):
    """A brief wrapper around string buffers for quick Yara rule matching."""

    def __init__(self, buffer):
        self.buffer = buffer

    def get_yara_quick(self, category, externals=None):
        results, rule = [], File.yara_rules[category]
        for match in rule.match(data=self.buffer, externals=externals):
            results.append(match.rule)
        return results

class YaraMatch(object):
    def __init__(self, match, category=None):
        self.name = match["name"]
        self.meta = match["meta"]
        self._decoded = {}
        self.offsets = match["offsets"]
        self.category = category

        self._strings = []
        for s in match["strings"]:
            self._strings.append(s.decode("base64"))

    def string(self, identifier, index=0):
        off, idx = self.offsets[identifier][index]
        return self._strings[idx]

    def strings(self, identifier):
        ret = []
        for off, idx in self.offsets[identifier]:
            ret.append(self._strings[idx])
        return ret

class ExtractedMatch(object):
    def __init__(self, match):
        self.category = match["category"]
        self.program = match.get("program")
        self.first_seen = match.get("first_seen")
        self.pid = match.get("pid")

        self.yara = []
        for ym in match["yara"]:
            self.yara.append(YaraMatch(ym))

        # Raw payload.
        self.raw = match.get("raw")
        self.info = match["info"]
